/*
 Copyright 2014-Present Algorithms in Motion LLC
 
 This file is part of FDTD++.
 
 FDTD++ is proprietary software: you can use it and/or modify it
 under the terms of the Algorithms in Motion License as published by
 Algorithms in Motion LLC, either version 1 of the License, or (at your
 option) any later version.
 
 FDTD++ is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 Algorithms in Motion License for more details.
 
 You should have received a copy of the Algorithms in Motion License
 along with FDTD++. If not, see <http://www.aimotionllc.com/licenses/>.
*/
// CREATED    : 9/7/2015
// LAST UPDATE: 9/30/2015

#include "statsxx/machine_learning/restricted_Boltzmann_machine/RBM.hpp"

// STL
#include <numeric>   // std::iota()
#include <stdexcept> // std::runtime_error
#include <tuple>     // std::tuple<>, std::make_tuple(), std::tie()
#include <vector>    // std::vector<>

// jScience
#include "jScience/linalg.hpp"      // Vector<>, Matrix<>, 
#include "jScience/stats/stats.hpp" // shuffle_data()

// stats++
#include "statsxx/distribution.hpp" // distribution::binomial()
#include "statsxx/machine_learning/activation_functions.hpp" // activation_function::Logistic


// basic algorithm for stochastically estimating the gradient of the log likelihood function -- see p. 117 of "DBN in C++ ..."
// Stochastically estimate the gradient of the log likelihood via K steps of
// contrastive-divergence-style Gibbs sampling -- see p. 117 of "DBN in C++ ..."
//
// @param K           number of Monte Carlo (Gibbs) iterations; must be >= 1
// @param mean_field  if true, use the mean-field approximation (propagate visible
//                    probabilities instead of sampled binary states)
// @param x           a single data vector (visible-layer input)
// @param x_binomial  whether the data is binomial (forwarded to reconstruct())
//
// @return (W_grad, b_grad, c_grad, q_data, error) where W_grad/b_grad/c_grad are
//         the gradient estimates for the weights and the visible/hidden biases,
//         q_data are the hidden activation probabilities given the data, and
//         error is the reconstruction error of x
//
// @throws std::runtime_error if K < 1
inline std::tuple<
                  Matrix<double>,
                  Vector<double>,
                  Vector<double>,
                  Vector<double>,
                  double
                  > restricted_Boltzmann_machine::RBM::grad_log_likelihood(
                                                                           const int K,               // number of Monte Carlo iterations
                                                                           const bool mean_field,     // mean-field approximation
                                                                           const Vector<double> &x,   // data
                                                                           const bool x_binomial      // whether the data is binomial
                                                                           )
{
    // at least one Gibbs step is required, since the model statistics below are
    // taken from the state reached at the end of the loop
    if(K < 1)
    {
        throw std::runtime_error("Error in grad_log_likelihood(): K < 1");
    }
    
    Matrix<double> W_grad;
    Vector<double> b_grad;
    Vector<double> c_grad;
    
    // << jmm: the following incurs extra cost, because a full reconstruction of the data is done to determine the error, yet p_data is not needed again -- see note below >>
    Vector<double> p_data;
    Vector<double> q_data;
    double error;
    std::tie(p_data, q_data, error) = this->reconstruct(x, x_binomial);
    
    // reconstruction probabilities under the _model_ distribution that ...
    Vector<double> p_model;          // each visible neuron will be one 
    Vector<double> q_model = q_data; // each hidden neuron will be one     
    Vector<double> v_model;          // sampled binary visible state (non-mean-field only)
    
    activation_function::Logistic logistic;
    
    // alternating Gibbs sampling: hidden sample -> visible probabilities ->
    // hidden probabilities, starting the chain at the data-driven hidden state
    for(int k = 0; k < K; ++k)
    {
        // sample binary hidden states from their activation probabilities
        Vector<double> h_model = distribution::binomial<double>(1, q_model);
        
        p_model = logistic.f(b + transpose(W)*h_model);
        
        // if k = 0, optionally reconstruct the error using the fast method
        
        if(mean_field)
        {
            // mean-field: propagate visible probabilities directly (lower variance)
            q_model = logistic.f(c + W*p_model);
        }
        else
        {
            // full sampling: draw binary visible states before the hidden update
            v_model = distribution::binomial<double>(1, p_model);
            q_model = logistic.f(c + W*v_model);
        }
    }
    
    // gradients are (model statistics) - (data statistics); the mean-field
    // branch uses probabilities, the sampled branch uses binary states
    if(mean_field)
    {
        b_grad = p_model - x;
        
        c_grad = q_model - q_data;
        
        W_grad = outer_product(q_model, p_model) - outer_product(q_data, x);
    }
    else
    {
        b_grad = v_model - x;
        
        // sample a binary hidden state for the data-side statistics as well
        Vector<double> h_data = distribution::binomial<double>(1, q_data);
        c_grad = q_model - h_data;
        
        W_grad = outer_product(q_model, v_model) - outer_product(h_data, x);
    }
    
    return std::make_tuple(W_grad, b_grad, c_grad, q_data, error);
}
